/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
 * Description: unit tests for round-tripping a custom attribute type through ge::AttrValue and ge::TensorDesc
 */

#include "gtest/gtest.h"
#include "securec.h"
#include "graph/attr_value.h"
#include "graph/tensor.h"
#include "graph/shape.h"
#include "graph/types.h"

namespace ct {
namespace {
// Parameters for one instantiation of CustomAttrTest.
struct CustomAttrParam {
    int index; // value stored into MyCustomAttr::index_ and expected back after each round trip
};

// Minimal custom attribute used to exercise ge::AttrValue's custom-type support.
// Its entire state is a single int, serialized as raw bytes.
class MyCustomAttr : public ge::IPipeCustomAttr {
public:
    // Serializes index_ into a newly allocated byte array.
    // On success: buffer points to sizeof(index_) bytes allocated with new[] and size == sizeof(index_).
    // On failure (allocation or copy error): buffer == nullptr and size == 0.
    // NOTE(review): ownership of the returned buffer passes to the caller, presumably released with
    // delete[] — confirm against the ge::IPipeCustomAttr contract.
    void ToBuffer(void*& buffer, size_t& size) const
    {
        buffer = nullptr;
        size = 0;

        // Allocate into a typed local instead of aliasing the void*& out-parameter through a cast.
        uint8_t* data = new (std::nothrow) uint8_t[sizeof(index_)];
        if (data == nullptr) {
            return;
        }
        if (memcpy_s(data, sizeof(index_), &index_, sizeof(index_)) != 0) {
            delete[] data;
            return;
        }

        // Publish the result only once it is fully written.
        buffer = data;
        size = sizeof(index_);
    }

    // Restores index_ from a buffer produced by ToBuffer.
    // Returns false (leaving index_ unchanged) if buffer is null, if size is smaller than sizeof(index_),
    // or if the copy fails; returns true otherwise.
    bool FromBuffer(const void* buffer, size_t size)
    {
        // Without this check a short buffer would be read past its end, because the copy
        // always consumes sizeof(index_) bytes regardless of the size argument.
        if (buffer == nullptr || size < sizeof(index_)) {
            return false;
        }
        return memcpy_s(&index_, sizeof(index_), buffer, sizeof(index_)) == 0;
    }

public:
    int index_ = 0; // the attribute payload; public so the test can set and inspect it directly
};
} // namespace

// Value-parameterized fixture: each instantiation receives one CustomAttrParam via GetParam().
// All setup/teardown hooks are intentionally empty; each test builds its own objects.
class CustomAttrTest : public testing::TestWithParam<CustomAttrParam> {
public:
    // Runs once before the first test of this fixture.
    static void SetUpTestCase()
    {
    }
    // Runs once after the last test of this fixture.
    static void TearDownTestCase()
    {
    }

protected:
    // Per-test hooks; marked override so a signature mismatch with testing::Test fails to compile.
    void SetUp() override
    {
    }
    void TearDown() override
    {
    }
};

// Round-trips a MyCustomAttr in two ways:
// (1) directly through ge::AttrValue, and
// (2) as a keyed attribute on a ge::TensorDesc.
// Each time, it checks that index_ comes back unchanged.
TEST_P(CustomAttrTest, CustomAttr)
{
    const CustomAttrParam& testParam = GetParam();

    ct::MyCustomAttr attr;
    attr.index_ = testParam.index;

    // Path 1: plain AttrValue. The first GetValue call reads a custom type from a still-empty
    // AttrValue; only that it can be made is exercised here, its result is not checked.
    ge::AttrValue direct;
    direct.GetValue<ct::MyCustomAttr>();
    direct = ge::AttrValue::CreateFrom<ct::MyCustomAttr>(attr);
    EXPECT_EQ(direct.GetValue<ct::MyCustomAttr>().Value().index_, testParam.index);

    // Path 2: attach the attribute to a tensor descriptor and fetch it back by key.
    ge::Shape shape({1, 1, 360, 640});
    ge::TensorDesc tensorDesc(shape, ge::DT_INT64);
    tensorDesc.SetAttr("customKey", ge::AttrValue::CreateFrom<ct::MyCustomAttr>(attr));

    ge::AttrValue fromDesc;
    tensorDesc.GetAttr("customKey", fromDesc);
    EXPECT_EQ(fromDesc.GetValue<ct::MyCustomAttr>().Value().index_, testParam.index);
}

namespace {
// Success case: the custom attribute carries index 301 and must survive every round trip intact.
constexpr CustomAttrParam custom_Attr_Success_001 { 301 };
} // namespace

// Registers CustomAttrTest once per parameter set, under the "CustomAttr" instantiation prefix.
INSTANTIATE_TEST_CASE_P(CustomAttr, CustomAttrTest, ::testing::Values(custom_Attr_Success_001));
} // namespace ct